# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #  
#
#' @title  Induce labelings from crowd-sourced elite-criticism codings of
#'          tweets distributed for annotation in the first and second rounds
#'          of crowd coding
#' @author Hauke Licht
#
# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #

# setup ----

set.seed(seed = 1234)

# load packages
# data wrangling
library("readr")
library("dplyr")
library("tidyr")
library("purrr")

# plotting
library("ggplot2")
library("ggridges")

# annotation models; install with
# renv::install("haukelicht/AnnotationModelsR@0.1.2")
library("AnnotationModelsR")

# project directory layout, relative to the working directory
base_path <- "."
data_path <- file.path(base_path, "data")
fits_path <- file.path(data_path, "fits")
codings_path <- file.path(data_path, "intermediate", "codings")
labelings_path <- file.path(data_path, "intermediate", "labelings")

# shared helpers (presumably the source of `entropy()` used below — verify)
helpers_path <- file.path(base_path, "code", "helpers")
source(file.path(helpers_path, "utils.R"))

# load codings data ----

# one file per crowd-coding round; `bind_rows(.id = )` records each file's
# position in the (unnamed) list ("1", "2") in a new leading `sample` column
codings <- file.path(codings_path, c("codings_sample_1.rds", "codings_sample_2.rds")) %>% 
  map(read_rds) %>% 
  bind_rows(.id = "sample")

# clean codings ----

# label class mappings: names are the raw coding values accepted as valid;
# values are display labels wrapped in literal double quotes (presumably so
# they can be parsed as plotmath expressions downstream — verify)
label_abbr <- c(
  "yes-general" = '"General"'
  , "yes-specific" = '"Specific"'
  , "yes-unsure" = '"Ambiguous"'
  , "no" = '"No"'
  , "cannot-answer" = '"Cannot answer"'
)

# determine valid codings: any value that is not a known label class is
# replaced by NA
codings <- mutate(codings, coding = ifelse(coding %in% names(label_abbr), coding, NA_character_))

count(codings, coding)

# compute and inspect coder stats ----

# one row per worker: workload, label diversity, label counts (most frequent
# first) and the distribution of coding durations
coder_stats <- codings %>% 
  group_by(worker_id) %>% 
  summarize(
    n_judgments = n_distinct(item_id),
    judgment_entropy = entropy(coding),
    judgment_set = list(rev(sort(table(coding)))),
    mean_duration = mean(duration),
    median_duration = median(duration),
    q10_duration = quantile(duration, .10),
    q25_duration = quantile(duration, .25),
    q75_duration = quantile(duration, .75),
    q90_duration = quantile(duration, .90),
    # flagged workers get their own ID here, all others NA: flag anyone with
    # at most 5 judgments, or whose label entropy is low relative to workload
    badbadnotgood = if_else(
      n_judgments <= 5 |
        judgment_entropy < .25 |
        (n_judgments > 50 & judgment_entropy < .5) |
        (n_judgments > 100 & judgment_entropy < 1),
      worker_id,
      NA_character_
    )
  ) %>% 
  ungroup()

# number of coders in/across samples: named integer vector with one entry
# per sample plus "both" (distinct workers over all samples)
J <- codings %>% 
  distinct(sample, worker_id) %>% 
  count(sample) %>% 
  with(setNames(n, sample))
J["both"] <- n_distinct(codings$worker_id)

# compute and inspect tweet stats -----

# per-tweet summary of the raw (uncleaned) codings
tweet_stats <- codings %>% 
  group_by(status_id, source, sample) %>% 
  summarize(
    median_duration = median(duration)
    , mean_duration = mean(duration)
    , n_judgments = n_distinct(worker_id)
    , judgment_entropy = entropy(coding)
    # label counts, most frequent first; invalid (NA) codings are not
    # tabulated, so the table is empty for a tweet without any valid coding
    , judgments = list(rev(sort(table(coding))))
    # modal label and its count (NA when there is no valid coding, instead of
    # aborting the whole summary)
    , mode_label = if (length(judgments[[1]]) > 0L) {
      names(judgments[[1]])[1L]
    } else {
      NA_character_
    }
    , mode_label_n = if (length(judgments[[1]]) > 0L) {
      judgments[[1]][[1L]]
    } else {
      NA_integer_
    }
    # tie: the runner-up label received as many judgments as the modal one
    , mode_is_tie = length(judgments[[1]]) > 1L &&
      judgments[[1]][[2L]] == mode_label_n
  ) %>% 
  ungroup()

# filter bad coders ----

# workers flagged in coder_stats (non-missing `badbadnotgood`)
bad_coders <- coder_stats$worker_id[!is.na(coder_stats$badbadnotgood)]
n_bad_workers <- length(bad_coders)

# number of coding rows (invalid ones included) contributed by flagged workers
n_bad_removed <- sum(codings$worker_id %in% bad_coders)

# visualize
# workload vs. label entropy per coder; flagged coders drawn as red crosses
# NOTE(review): with the log10 x scale the jitter width is in log10 units
# (about +/-2.3%), so the subtitle's "1%" is approximate
coder_stats %>%   
  ggplot(aes(n_judgments, judgment_entropy, color = !is.na(badbadnotgood), shape = !is.na(badbadnotgood))) +
    geom_jitter(alpha = .9, size = 1, width = .01, height = .01) +
    scale_x_log10() + 
    scale_color_manual(breaks = c(FALSE, TRUE), values = c("black", "red")) +
    scale_shape_manual(breaks = c(FALSE, TRUE), values = c(1, 3)) +
    # shape repeats the color mapping; suppress its separate legend
    guides(shape = "none") + 
    labs(
      title = "Number of judgments against coder judgment entropy"
      , subtitle = "Vertical and horizontal jitter of max. 1% added to avoid over-plotting"
      , x = "Number of judgments (base-10 log scale)"
      , y = "Coder judgment entropy"
      , color = "Contributions removed"
    ) + 
    theme(legend.position = "bottom")

# remove ... 
tmp <- codings %>% 
  # ... invalid codings
  filter(!is.na(coding)) %>% 
  # ... codings by "spam" coders
  filter(!(worker_id %in% bad_coders)) %>% 
  # ... codings made in less than 4 seconds (missing durations are dropped too)
  filter(duration >= 4)

# tweets whose modal (most frequent) label is "cannot-answer".
# sort() keeps tied labels in table (alphabetical) order and rev() then puts
# the alphabetically last label first, so "cannot-answer" never wins a tie.
# NOTE(review): this is a plurality rule, not a strict majority.
majority_ca_tweets <- tmp %>% 
  group_by(status_id, sample) %>% 
  summarise(
    judgments = list(rev(sort(table(coding)))),
    mode_label = names(judgments[[1]])[[1]]
  ) %>% 
  ungroup() %>% 
  filter(mode_label == "cannot-answer")

n_majority_ca_tweets <- nrow(majority_ca_tweets)  

# cleaned codings: also remove
codings_cleaned <- tmp %>% 
  # ... all codings of majority cannot-answer tweets, matched explicitly on the
  # keys majority_ca_tweets was grouped by (a natural join would silently pick
  # up any other shared column name)
  anti_join(majority_ca_tweets, by = c("status_id", "sample")) %>% 
  # ... all other "Cannot answer" codings
  filter(coding != "cannot-answer")

# label distribution in remaining codings
codings_cleaned %>% 
  count(coding)

# distribution of No. judgments in remaining codings: how many items received
# 1, 2, ... judgments
codings_cleaned %>% 
  count(item_id, name = "n_judgments") %>% 
  count(n_judgments)

# recompute tweet stats ----

# per-tweet summary of the cleaned codings (replaces the raw-codings version)
tweet_stats <- codings_cleaned %>% 
  group_by(status_id, source, sample) %>% 
  summarize(
    median_duration = median(duration)
    , mean_duration = mean(duration)
    , n_judgments = n_distinct(worker_id)
    , judgment_entropy = entropy(coding)
    # label counts: sort() ascending, rev() so the first cell is the most
    # frequent label
    , judgments = list(rev(sort(table(coding))))
    # modal label and its count (NA guard kept for an empty table)
    , mode_label = if (length(judgments[[1]]) > 0L) {
      names(judgments[[1]])[1L]
    } else {
      NA_character_
    }
    , mode_label_n = if (length(judgments[[1]]) > 0L) {
      judgments[[1]][[1L]]
    } else {
      NA_integer_
    }
    # tie: the runner-up label received as many judgments as the modal one
    , mode_is_tie = length(judgments[[1]]) > 1L &&
      judgments[[1]][[2L]] == mode_label_n
  ) %>% 
  ungroup()

# fit models ----

#' Relative frequency of each coding label within selected samples
#'
#' @param x data frame with (at least) columns `coding` and `sample`
#' @param samples sample IDs whose codings are tabulated; defaults to all
#'   samples present in `x`
#' @return 1-d array (a `table`) of label proportions summing to 1, named by
#'   label; NA codings are not counted
#' @note errors if no row of `x` belongs to `samples` (an empty prior would
#'   otherwise be passed on silently)
compute_label_props <- function(x, samples = unique(x[["sample"]])) {
  keep <- x[["sample"]] %in% samples
  if (!any(keep)) {
    stop(
      "no codings found for sample(s): ", paste(samples, collapse = ", "),
      call. = FALSE
    )
  }
  as.array(prop.table(table(x[["coding"]][keep])))
}

# sample 1 only 
# Dawid-Skene model fitted by EM on the first crowd-coding sample, using that
# sample's empirical label distribution as prevalence prior
em_fit_s1 <- em(
  codings_cleaned[codings_cleaned[["sample"]] == "1", ],
  item.col = status_id,
  annotator.col = worker_id,
  label.col = coding,
  max.iters = 500,
  .prevalence.prior = compute_label_props(codings_cleaned, samples = "1"),
  .min.relative.diff = 1e-4
)

# cache the fit; an existing file is never overwritten
fp <- file.path(fits_path, "dawidskene_model_sample_1.rds")
if (!file.exists(fp)) {
  write_rds(em_fit_s1, fp)
}

# sample 2 only
# same model as for sample 1, fitted on the second crowd-coding sample with
# that sample's empirical label distribution as prevalence prior
em_fit_s2 <- em(
  codings_cleaned[codings_cleaned[["sample"]] == "2", ],
  item.col = status_id,
  annotator.col = worker_id,
  label.col = coding,
  max.iters = 500,
  .prevalence.prior = compute_label_props(codings_cleaned, samples = "2"),
  .min.relative.diff = 1e-4
)

# cache the fit; an existing file is never overwritten
fp <- file.path(fits_path, "dawidskene_model_sample_2.rds")
if (!file.exists(fp)) {
  write_rds(em_fit_s2, fp)
}

# both samples
# pooled model over all cleaned codings; the prior is the label distribution
# across samples "1" and "2"
em_fit <- em(
  codings_cleaned,
  item.col = status_id,
  annotator.col = worker_id,
  label.col = coding,
  max.iters = 500,
  .prevalence.prior = compute_label_props(codings_cleaned, samples = c("1", "2")),
  .min.relative.diff = 1e-4
)

# cache the fit; an existing file is never overwritten
fp <- file.path(fits_path, "dawidskene_model_pooled_samples.rds")
if (!file.exists(fp)) {
  write_rds(em_fit, fp)
}

# save labeling ----

# class-probability estimates of the pooled model (presumably keyed by
# `status_id`, as the join below assumes — verify)
labelings <- em_fit$est_class_probs

# unique tweet metadata rows, enriched with the cleaned tweet-level stats and
# right-joined onto the estimates so every labeling row is kept
# NOTE(review): a status_id coded in both samples yields one metadata row per
# sample, which duplicates its labeling rows here
labeled_tweets <- codings_cleaned %>% 
  select(
    item_id,
    country_iso3c, party_id, party_name_short,
    user_id, status_id,
    text_en = source,
    prob_political, cluster_id, pred_label,
    sample
  ) %>% 
  unique() %>% 
  left_join(
    tweet_stats %>% 
      select(sample, status_id, median_duration, judgment_entropy, n_judgments),
    by = c("sample", "status_id")
  ) %>% 
  right_join(labelings, by = "status_id")

# write once; an existing file is never overwritten
fp <- file.path(labelings_path, "dawidskene_labelings_pooled_samples.csv")
if (!file.exists(fp)) {
  write_csv(labeled_tweets, fp)
}
